{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[View the runnable example on GitHub](https://github.com/intel-analytics/BigDL/tree/main/python/nano/tutorial/notebook/training/pytorch-lightning/accelerate_pytorch_lightning_training_ipex.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Accelerate PyTorch Lightning Training using Intel® Extension for PyTorch*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `bigdl.nano.pytorch.Trainer` API extends the PyTorch Lightning `Trainer` with multiple integrated optimizations. Instantiate a BigDL-Nano `Trainer` with `use_ipex=True` to apply [Intel® Extension for PyTorch*](https://github.com/intel/intel-extension-for-pytorch) (also known as IPEX) for an extra performance boost on Intel hardware."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nbsphinx": "hidden"
   },
   "source": [
    "To enable IPEX, you need to install BigDL-Nano for PyTorch first:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "!pip install --pre --upgrade bigdl-nano[pytorch] # install the nightly-built version\n",
    "!source bigdl-nano-init # set environment variables"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 📝 **Note**\n",
    ">\n",
    "> Before starting your PyTorch Lightning application, it is highly recommended to run `source bigdl-nano-init` to set several environment variables based on your current hardware. Empirically, these variables bring a substantial performance increase for most PyTorch Lightning training workloads."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nbsphinx": "hidden"
   },
   "source": [
    "> ⚠️ **Warning**\n",
    "> \n",
    "> For Jupyter Notebook users, we recommend running the commands above, especially `source bigdl-nano-init`, before the Jupyter kernel is started; otherwise some of the optimizations may not take effect."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an example, let's take a self-defined `LightningModule` (based on a [ResNet-18 model](https://pytorch.org/vision/main/models/generated/torchvision.models.resnet18.html) pretrained on the ImageNet dataset) and dataloaders, and finetune the model on the [OxfordIIITPet dataset](https://pytorch.org/vision/main/generated/torchvision.datasets.OxfordIIITPet.html):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "# Define LightningModule and dataloader\n",
    "\n",
    "import torch\n",
    "from torchvision.models import resnet18\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "class MyLightningModule(pl.LightningModule):\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.model = resnet18(pretrained=True)\n",
    "        num_ftrs = self.model.fc.in_features\n",
    "        # here the size of each output sample is set to 37.\n",
    "        self.model.fc = torch.nn.Linear(num_ftrs, 37)\n",
    "        self.criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.model(x)\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        x, y = batch\n",
    "        output = self.model(x)\n",
    "        loss = self.criterion(output, y)\n",
    "        self.log('train_loss', loss)\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        x, y = batch\n",
    "        output = self.forward(x)\n",
    "        loss = self.criterion(output, y)\n",
    "        pred = torch.argmax(output, dim=1)\n",
    "        acc = torch.sum(y == pred).item() / (len(y) * 1.0)\n",
    "        metrics = {'test_acc': acc, 'test_loss': loss}\n",
    "        self.log_dict(metrics)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        return torch.optim.SGD(self.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)\n",
    "\n",
    "\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import OxfordIIITPet\n",
    "from torch.utils.data.dataloader import DataLoader\n",
    "\n",
    "def create_dataloaders():\n",
    "    train_transform = transforms.Compose([transforms.Resize(256),\n",
    "                                          transforms.RandomCrop(224),\n",
    "                                          transforms.RandomHorizontalFlip(),\n",
    "                                          transforms.ColorJitter(brightness=.5, hue=.3),\n",
    "                                          transforms.ToTensor(),\n",
    "                                          transforms.Normalize([0.485, 0.456, 0.406],\n",
    "                                                               [0.229, 0.224, 0.225])])\n",
    "    val_transform = transforms.Compose([transforms.Resize(256),\n",
    "                                        transforms.CenterCrop(224),\n",
    "                                        transforms.ToTensor(),\n",
    "                                        transforms.Normalize([0.485, 0.456, 0.406],\n",
    "                                                             [0.229, 0.224, 0.225])])\n",
    "\n",
    "    # apply data augmentation to the tarin_dataset\n",
    "    train_dataset = OxfordIIITPet(root=\"/tmp/data\", transform=train_transform, download=True)\n",
    "    val_dataset = OxfordIIITPet(root=\"/tmp/data\", transform=val_transform)\n",
    "\n",
    "    # obtain training indices that will be used for validation\n",
    "    indices = torch.randperm(len(train_dataset))\n",
    "    val_size = len(train_dataset) // 4\n",
    "    train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])\n",
    "    val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])\n",
    "\n",
    "    # prepare data loaders\n",
    "    train_dataloader = DataLoader(train_dataset, batch_size=32)\n",
    "    val_dataloader = DataLoader(val_dataset, batch_size=32)\n",
    "\n",
    "    return train_dataloader, val_dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = MyLightningModule()\n",
    "train_loader, val_loader = create_dataloaders()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; _The definition of_ `MyLightningModule` _and_ `create_dataloaders` _can be found in the_ [runnable example](https://github.com/intel-analytics/BigDL/tree/main/python/nano/tutorial/notebook/training/pytorch-lightning/accelerate_pytorch_lightning_training_ipex.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To use IPEX for better performance, simply **import the BigDL-Nano Trainer and set** `use_ipex` **to** `True`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bigdl.nano.pytorch import Trainer\n",
    "\n",
    "trainer = Trainer(max_epochs=5, use_ipex=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can then run the usual training (and evaluation) steps with the IPEX-accelerated trainer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.fit(model, train_dataloaders=train_loader)\n",
    "trainer.validate(model, dataloaders=val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 📚 **Related Readings**\n",
    "> \n",
    "> - [How to install BigDL-Nano](https://bigdl.readthedocs.io/en/latest/doc/Nano/Overview/install.html)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.13 ('nano-pytorch': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.7.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "09344c7f3239fd422839751f876786d6b1a624c40f19af1b43cb2737f421c2b2"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
